import datetime  # 用于datetime对象操作
import os.path  # 用于管理路径
import sys  # 用于在argvTo[0]中找到脚本名称
import backtrader as bt  # 引入backtrader框架
import pandas as pd

stk_num = 3  # 回测股票数目

# 创建策略
class SmaCross(bt.Strategy):
    # 可配置策略参数
    params = dict(
        pfast=5,  # 短期均线周期
        pslow=60,  # 长期均线周期
        poneplot=False,  # 是否打印到同一张图
        pstake=1000  # 单笔交易股票数目
    )

    def __init__(self):
        self.inds = dict()
        for i, d in enumerate(self.datas):      #  对于多支股票数据，需要使用enumerate拆包数据

            self.inds[d] = dict()
            self.inds[d]['sma1'] = bt.ind.SMA(d.close, period=self.p.pfast)  # 短期均线
            self.inds[d]['sma2'] = bt.ind.SMA(d.close, period=self.p.pslow)  # 长期均线
            self.inds[d]['cross'] = bt.ind.CrossOver(self.inds[d]['sma1'],
                                                     self.inds[d]['sma2'], plot=False)  # 交叉信号
            # # 跳过第一只股票data，第一只股票data作为主图数据
            # if i > 0:
            #     if self.p.poneplot:
            #         d.plotinfo.plotmaster = self.datas[0]

    def next(self):
        for i, d in enumerate(self.datas):    #  对于多支股票数据，需要使用enumerate拆包数据
            global dt, dn
            dt,dn= self.datetime.date(), d._name  # 获取时间及股票代码
            pos = self.getposition(d).size
            if not pos:  # 不在场内，则可以买入
                if self.inds[d]['cross'] > 0:  # 如果金叉
                    self.buy(data=d, size=self.p.pstake)  # 买买买
            elif self.inds[d]['cross'] < 0:  # 在场内，且死叉
                self.close(data=d)  # 卖卖卖

    def notify_order(self, order):
        if order.status in [order.Submitted, order.Accepted]:
            # 提交给代理或者由代理接收的买/卖订单 - 不做操作
            return
        # 检查订单是否执行完毕
        # 注意：如果没有足够资金，代理会拒绝订单
        if order.status in [order.Completed]:
            if order.isbuy():
                print("\033[35m日期：{}，股票代码：{}, BUY EXECUTED,Price:{:.2f},Cost:{:.2f},Comm:{:.2f}\033[0m"
                      .format(self.datas[0].datetime.date(0),dn,
                              order.executed.price, order.executed.value, order.executed.comm))
                # 卖
            else:
                print("\033[32m日期：{}，股票代码：{},SELL EXECUTED,Price:{:.2f},Cost:{:.2f},Comm:{:.2f}\033[0m"
                      .format(self.datas[0].datetime.date(0),dn,
                              order.executed.price, order.executed.value, order.executed.comm))

            self.bar_executed = len(self)
            # 如果指令取消/交易失败, 报告结果

    # def stop(self):
    #     print("\033[32m Ending value:{:.2f}\033[0m" .format(self.broker.getvalue()))

    def notify_trade(self, trade):
        if not trade.isclosed:
            return
        # 显示交易的毛利率和净利润
        print("\033[31m 交易日期:{} 股票代码：{} 毛利润:{:.2f}, 净利润:{:.2f}\033[0m"
              .format(self.datetime.date(),dn,trade.pnl, trade.pnlcomm))
        # self.log('OPERATION PROFIT, GROSS %.2f, NET %.2f' %
        #          (trade.pnl, trade.pnlcomm))

cerebro = bt.Cerebro()  # 创建cerebro
# 读入股票代码
stk_code_num=[]
stk_code_file =os.listdir('../data/tdx/day/')            #  数据相对于脚本文件的地址
for stk_code in stk_code_file:
    index1 = stk_code.rfind(".")
    newname_code = stk_code[:index1]
    # stk_path = '../data/tdx/day/'+stk_code
    # stk_data = pd.read_csv(stk_path, encoding='gbk')
    # stk_data['code']=newname_code
    stk_code_num.append(newname_code)
stk_pools = pd.DataFrame(stk_code_num)
if stk_num > stk_pools.shape[0]:
    print('股票数目不能大于%d' % stk_pools.shape[0])
    exit()
for i in range(stk_num):
    stk_code = stk_pools.iloc[i,0]
    # 读入数据
    datapath = '../data/tdx/day/' + stk_code + '.csv'
    # 创建价格数据
    data = bt.feeds.GenericCSVData(
        dataname=datapath,
        fromdate=datetime.datetime(2020, 12, 1),
        todate=datetime.datetime(2021, 12, 1),
        nullvalue=0.0,
        dtformat=('%Y-%m-%d'),
        datetime=0,
        open=1,
        high=2,
        low=3,
        close=4,
        volume=5,
        openinterest=-1
    )
    # 在Cerebro中添加股票数据
    cerebro.adddata(data, name=stk_code)
# 设置启动资金
cerebro.broker.setcash(100000.0)
# 设置交易单位大小
# cerebro.addsizer(bt.sizers.FixedSize, stake = 5000)
# 设置佣金为千分之一
cerebro.broker.setcommission(commission=0.001)
cerebro.addstrategy(SmaCross, poneplot=False)  # 添加策略
cerebro.run()  # 遍历所有数据
# 打印最后结果
print('Final Portfolio Value: %.2f' % cerebro.broker.getvalue())
# 绘图时间段调整位置
start_plot =datetime.datetime(2021, 9, 11)
end_plot =datetime.datetime(2021, 9, 25)
cerebro.plot(style="candle", start=start_plot, end=end_plot)  # 绘图
